import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import adept_envs # type: ignore
import gym
import math
import cv2
import numpy as np
from PIL import Image
import os
import torchvision.transforms as T
from vip import load_vip
import pickle
import time
torch.set_printoptions(edgeitems=10, linewidth=500)
# Basic global variables shared by the environment wrapper and the dataset below:
# the compute device, the frozen VIP image encoder, and its preprocessing pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device is", device, flush=True)
vip = load_vip()
vip.eval()  # frozen encoder: only ever called under torch.no_grad() in this file's visible code
vip = vip.to(device)
# Resize the shorter side to 256, center-crop 224x224, then convert to a float tensor.
transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),  # ToTensor() divides by 255
])


class Robotic_Environment: # This will only be used for evaluating a trained model
    """Evaluation wrapper around the gym 'kitchen_relax-v1' environment.

    The environment id is presumably registered by the ``adept_envs`` import at the top
    of the file. The wrapper resets the simulator to a recorded start state (optionally
    perturbed), steps it, records resized BGR video frames, and returns model-ready
    state vectors.

    Args:
        video_resolution: (width, height) passed to ``cv2.resize`` / ``cv2.VideoWriter``.
        gaussian_noise: if True, multiply the recorded qpos/qvel elementwise by
            ``1 + N(0, 0.03)``. This is relative noise, so exact zeros stay zero.
        camera_number: currently unused here.
        reset_information: ``(qpos, qvel)`` arrays written into ``env.sim.data``.
        in_hand_eval: currently unused here.
    """

    def __init__(self, video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval):

        env = gym.make('kitchen_relax-v1')
        self.env = env.env  # unwrap one gym wrapper layer to reach sim / _get_obs
        self.video_frames = [] # These are the frames of video saved for evaluation
        self.video_resolution = video_resolution
        self.env.reset()
        if(gaussian_noise):
            mean = 0  # Mean of the Gaussian noise
            std_dev =0.03  # Standard deviation of the Gaussian noise
            self.env.sim.data.qpos[:] = reset_information[0] * (1 + np.random.normal(mean, std_dev, reset_information[0].shape))
            self.env.sim.data.qvel[:] = reset_information[1] * (1 + np.random.normal(mean, std_dev, reset_information[1].shape))
        else:
            self.env.sim.data.qpos[:] = reset_information[0]
            self.env.sim.data.qvel[:] = reset_information[1]
        self.env.sim.forward() # Propagate the overwritten qpos/qvel through the simulator

    def _render_rgb(self):
        """Render the current frame and return it as a fresh numpy RGB array."""
        # np.array copies, giving a contiguous buffer even if render returns a flipped view.
        return np.array(self.env.render(mode='rgb_array'))

    def _embed_current_frame(self):
        """Render the current frame and return its VIP embedding as a Python list."""
        rgb_array = self._render_rgb()
        preprocessed_image = transforms(Image.fromarray(rgb_array.astype(np.uint8))).reshape(-1, 3, 224, 224)
        preprocessed_image = preprocessed_image.to(device)
        with torch.no_grad():
            # transforms' ToTensor() scaled to [0, 1]; VIP is fed the 0-255 range.
            embedding = vip(preprocessed_image * 255.0)
        return embedding.cpu().tolist()[0]

    def step(self, action):
        """Execute ``action`` and append the resulting frame (BGR, resized) to the video buffer."""
        self.env.step(np.array(action))
        bgr_array = cv2.cvtColor(self._render_rgb(), cv2.COLOR_RGB2BGR)
        self.video_frames.append(cv2.resize(bgr_array, self.video_resolution))

    def get_current_state(self, space): # This is the state in the format specified as input
        """Return the current state as a flat list in the layout the policy was trained on.

        Args:
            space: ``"joint_space"`` (observation vector), ``"image_embedding"`` (VIP
                embedding of the rendered frame), or ``"both"`` (embedding followed by
                the observation vector).

        Raises:
            ValueError: if ``space`` is not one of the values above. Previously this
                silently returned None.
        """
        if space == "joint_space":
            return self.env._get_obs().tolist()
        if space == "both":
            # Image embedding first, joints second: same order as TrajectoryDataset.state_embedding.
            return self._embed_current_frame() + self.env._get_obs().tolist()
        if space == "image_embedding":
            return self._embed_current_frame()
        raise ValueError(f"Unknown state space {space!r}; expected 'joint_space', 'both' or 'image_embedding'")

    def save_video(self, video_filename, video_filename_in_hand):
        """Write every 4th recorded frame to ``video_filename`` as 30 fps mp4v.

        ``video_filename_in_hand`` is accepted for interface compatibility but unused.
        """
        video_fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video_out = cv2.VideoWriter(video_filename, video_fourcc, 30.0, self.video_resolution)
        try:
            for frame in self.video_frames[::4]:
                video_out.write(frame)
        finally:
            # Always release the writer so the file handle/container is closed even on error.
            video_out.release()

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding looked up by explicit per-token timestamps.

    Instead of using each token's index in the sequence, ``forward`` adds the encoding
    row for each token's own timestamp. This lets padded/repeated frames and an appended
    subgoal token carry their real time index.

    Args:
        d_model: token embedding width (odd widths are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of precomputed timestamps; timestamps must lie in ``[0, max_len)``.
    """

    def __init__(self, d_model, dropout, max_len=5000):  # Assuming 5000 is the maximum length of any trajectory
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, 1, d_model)  # Shape: (max_len, 1, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # Shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # Shape: (ceil(d_model/2),)
        pe[:, 0, 0::2] = torch.sin(position * div_term)  # Apply sine to even indices
        # There are only floor(d_model/2) odd columns, one fewer than the even ones when d_model
        # is odd, so drop the last frequency there (a no-op for even d_model).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])  # Apply cosine to odd indices
        self.register_buffer('pe', pe)  # Buffer: moves with .to(device) and is saved, but is not trained

    def forward(self, x, timestamps):
        """Add the encoding for each token's timestamp to ``x`` and apply dropout.

        Args:
            x: tensor of shape (batch_size, context_length, d_model).
            timestamps: (batch_size, context_length) tensor of integer-valued time
                indices; float values are truncated by ``.long()``.
        """
        timestamps = timestamps.long().to(self.pe.device)  # Shape: (batch_size, context_length)
        positions = self.pe[timestamps]  # Shape: (batch_size, context_length, 1 , d_model)
        positions = positions.squeeze(2)  # Shape: (batch_size, context_length, d_model)
        x = x + positions  # (batch_size, context_length, d_model) + (batch_size, context_length, d_model)
        return self.dropout(x)

class CustomTransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer: self-attention, then a two-layer feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``. ``dropout`` is applied
    only inside the attention module; the residual and feed-forward paths have none.
    Inputs are sequence-first, ``(seq_len, batch, d_model)``, because the attention
    module uses the default ``batch_first=False``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, activation="relu", dropout=0.0):
        super().__init__()
        # Keep this creation order: parameter initialisation draws from the global RNG.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        # Any value other than "relu" selects GELU.
        self.activation = nn.GELU() if activation != "relu" else nn.ReLU()
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise: Xavier-uniform weight matrices, zero biases, identity LayerNorms."""
        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.MultiheadAttention):
                # Only matrices (in_proj_weight, out_proj.weight); 1-D biases are left alone.
                for matrix in (p for p in module.parameters() if p.dim() > 1):
                    nn.init.xavier_uniform_(matrix)
                continue
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, src, attention_mask):
        """Apply the layer to ``src`` (seq_len, batch, d_model) under ``attention_mask``."""
        attended, _ = self.self_attn(src, src, src, attn_mask=attention_mask)
        hidden = self.norm1(src + attended)  # residual + norm around attention
        feed_forward = self.linear2(self.activation(self.linear1(hidden)))
        return self.norm2(hidden + feed_forward)  # residual + norm around the FFN

class TransformerPolicy_Custom(nn.Module):
    """Causal transformer policy mapping a window of state tokens to one flat action vector.

    Input tokens of shape (batch, context_length, state_dimension) receive
    timestamp-based positional encoding. They then pass through ``num_encoder_layers``
    encoder layers under a square-subsequent (causal) mask. The result is flattened and
    projected by a single linear layer to ``output_dimension`` values.
    """

    def __init__(self, state_dimension, output_dimension, context_length, nhead, num_encoder_layers, dim_feedforward, dropout, activation="relu"):
        super().__init__()
        self.pos_encoder = PositionalEncoding(state_dimension, dropout)
        stacked_layers = []
        for _ in range(num_encoder_layers):
            stacked_layers.append(
                CustomTransformerEncoderLayer(
                    d_model=state_dimension,
                    nhead=nhead,
                    dim_feedforward=dim_feedforward,
                    activation=activation,
                    dropout=dropout,
                )
            )
        self.transformer_encoder_layers = nn.ModuleList(stacked_layers)
        # The head sees the whole flattened window, so inputs must have exactly context_length tokens.
        self.output_layer = nn.Linear(state_dimension * context_length, output_dimension)
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform the output head's weight and zero its bias."""
        head = self.output_layer
        nn.init.xavier_uniform_(head.weight)
        if head.bias is not None:
            nn.init.zeros_(head.bias)

    def forward(self, current_state, timestamps):
        """Return (batch, output_dimension) actions for tokens ``current_state`` at ``timestamps``."""
        encoded = self.pos_encoder(current_state, timestamps)
        tokens = encoded.transpose(0, 1)  # (context_length, batch, state_dimension) for the encoder
        causal_mask = nn.Transformer.generate_square_subsequent_mask(tokens.shape[0], tokens.device)

        for layer in self.transformer_encoder_layers:
            tokens = layer(tokens, causal_mask)

        batch_major = tokens.transpose(0, 1)  # back to (batch, context_length, state_dimension)
        flattened = batch_major.reshape(batch_major.shape[0], -1)
        return self.output_layer(flattened)

class TrajectoryDataset(Dataset): # Dataset for Behavioural cloning
    """Behavioural-cloning dataset of (state window [+ subgoal], action chunk, timestamps) samples.

    Each directory in ``Trajectory_directories`` (under ``base_directory``) must contain:
    - ``data.pkl`` with ``observations`` and ``actions`` arrays;
    - ``camera_<camera_number>.avi`` with at least one frame per observation row;
    - when ``subgoal_conditioned``, ``<subgoal_directory_path>/<frame>.png`` files whose
      integer names mark subgoal frames. The earliest one is dropped.

    Intermediate per-frame rows are laid out as
    60 observation values | 3 zero buffer values | 1 frame index | 9 action values | VIP embedding.
    Actions are therefore columns 64:73 and the embedding is columns 73:1097 (a
    1024-dim embedding, per that slice).
    """

    def __init__(self, Trajectory_directories, base_directory , state_space, subgoal_directory_path, camera_number, action_chunking,subgoal_conditioned, subgoal_frames_delta, subgoal_change_format , context_length):
        self.Trajectory_directories = Trajectory_directories # List of all the directories 
        self.base_directory= base_directory
        self.state_space = state_space
        self.subgoal_directory_path = subgoal_directory_path
        self.camera_number = camera_number
        self.action_chunking = action_chunking
        self.subgoal_conditioned = subgoal_conditioned
        self.subgoal_frames_delta = subgoal_frames_delta
        self.subgoal_change_format = subgoal_change_format
        self.context_length = context_length
        self.trajectories = self._load_trajectories()

    def _read_csv(self, file_path, directory):
        """Build every training sample for one trajectory directory.

        Despite the name this reads a pickle, not a CSV; the name is kept for compatibility.
        """
        data = self._load_rows(file_path)
        self._append_image_embeddings(data, directory)
        list_of_subgoals = self._load_subgoal_indices(directory) if self.subgoal_conditioned else None
        return self._build_samples(data, list_of_subgoals)

    def _load_rows(self, file_path):
        """Read ``data.pkl`` into per-frame rows: observation + 3 zeros + frame index + action."""
        # pickle.load can execute arbitrary code: only point this at trusted dataset files.
        with open(file_path, 'rb') as f:
            data_dict = pickle.load(f)
        observations = data_dict['observations']
        actions = data_dict['actions']
        return [
            list(observations[i]) + [0., 0., 0.] + [i] + list(actions[i])
            for i in range(observations.shape[0])
        ]

    def _append_image_embeddings(self, data, directory):
        """Extend each row in ``data`` (in place) with the VIP embedding of its video frame."""
        video_path = f"{self.base_directory}/{directory}/camera_{self.camera_number}.avi"
        cap = cv2.VideoCapture(video_path)
        try:
            for i, row in enumerate(data):
                ret, frame = cap.read()  # ret is False when no frame could be read
                if not ret:
                    raise RuntimeError(
                        f"Could not read frame {i} of {video_path}: the video is missing or "
                        f"has fewer frames than data.pkl has rows ({len(data)})"
                    )
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                preprocessed_image = preprocessed_image.to(device)
                with torch.no_grad():
                    embedding = vip(preprocessed_image * 255.0)
                row.extend(embedding.cpu().tolist()[0])
        finally:
            cap.release()

    def _load_subgoal_indices(self, directory):
        """Return the sorted subgoal frame indices for ``directory``, minus the earliest one."""
        subgoals_directory = f"{self.base_directory}/{directory}/{self.subgoal_directory_path}"
        list_of_subgoals = sorted(
            int(f.replace('.png', '')) for f in os.listdir(subgoals_directory) if f.endswith('.png')
        )
        list_of_subgoals.pop(0)  # the earliest frame (start state) is not a subgoal
        return list_of_subgoals

    def _build_samples(self, data, list_of_subgoals):
        """Turn per-frame rows into ``[state_window, action_chunk, timestamps]`` samples.

        The state window holds the last ``history_length`` frames, left-padded with frame 0.
        When subgoal-conditioned, the current subgoal's state is appended with timestamp
        ``last + 1``. With ``subgoal_change_format == "same_network"`` the action gets a
        leading 1/0 flag for "within ``subgoal_frames_delta`` frames of the subgoal". The
        action chunk covers ``action_chunking`` steps and is zero-padded past the end.
        """
        history_length = self.context_length
        if self.subgoal_conditioned:
            history_length -= 1  # one context slot is reserved for the subgoal token
        predicts_subgoal_flag = self.subgoal_conditioned and self.subgoal_change_format == "same_network"

        output = []
        for i in range(len(data)):
            state = []
            timestamps = []
            for _ in range(history_length - i - 1):  # pad early frames with the first state
                state.append(self.state_embedding(data[0]))
                timestamps.append(0)
            for j in range(max(0, i + 1 - history_length), i + 1):
                state.append(self.state_embedding(data[j]))
                timestamps.append(j)

            if self.subgoal_conditioned:
                subgoal_index = self.get_subgoal_index(i, list_of_subgoals)
                state.append(self.state_embedding(data[subgoal_index]))
                timestamps.append(timestamps[-1] + 1)

            action = []
            if predicts_subgoal_flag:
                action.append(1 if abs(i - subgoal_index) <= self.subgoal_frames_delta else 0)
            action += data[i][64:73]
            for j in range(i + 1, i + self.action_chunking):
                action += [0.] * 9 if j >= len(data) else data[j][64:73]
            output.append([state, action, timestamps])
        return output

    def state_embedding(self , data):
        """Select the state features of one row according to ``self.state_space``.

        ``"both"`` returns the image embedding followed by the 60 observation values.

        Raises:
            ValueError: for an unknown ``state_space``. Previously this raised
                UnboundLocalError.
        """
        if self.state_space == "joint_space":
            return data[0:60]
        if self.state_space == "both":
            return data[73:1097] + data[0:60]
        if self.state_space == "image_embedding":
            return data[73:1097]
        raise ValueError(f"Unknown state_space {self.state_space!r}; expected 'joint_space', 'both' or 'image_embedding'")

    def get_subgoal_index(self, ind, list_of_subgoals): # This gives the subgoal number for some index
        """Return the first subgoal frame at or after ``ind``.

        Frames after the final subgoal keep the final subgoal as their target. Previously
        they got ``None``, which crashed the caller with a TypeError.

        Raises:
            ValueError: if ``list_of_subgoals`` is empty.
        """
        for subgoal in list_of_subgoals:
            if subgoal >= ind:
                return subgoal
        if not list_of_subgoals:
            raise ValueError("list_of_subgoals is empty: a trajectory needs at least one subgoal after its start frame")
        return list_of_subgoals[-1]

    def _load_trajectories(self):
        """Concatenate the samples of every trajectory directory into one flat list."""
        trajectories = []
        for directory in self.Trajectory_directories:
            file_path = f"{self.base_directory}/{directory}/data.pkl"
            trajectories.extend(self._read_csv(file_path, directory))
        return trajectories

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx): # This gives the exact state, action, timestamps tuple
        """Return ``(state, action, timestamps)`` float32 tensors for sample ``idx``."""
        trajectory_data = self.trajectories[idx]
        state = torch.tensor(trajectory_data[0] ,  dtype=torch.float32)
        action = torch.tensor(trajectory_data[1] ,  dtype=torch.float32)
        timestamps = torch.tensor(trajectory_data[2], dtype=torch.float32)
        return (state , action,  timestamps)

# Define custom loss function
class BCEWithLogitsLoss_MSELoss(nn.Module): # Column 0: subgoal-reached logit; remaining columns: regressed action chunk
    """Joint loss for the "same_network" subgoal-conditioned policy.

    Column 0 of the prediction is a logit for "current subgoal reached" and is trained
    with pos-weighted BCE-with-logits. All remaining columns are the action chunk and
    are trained with MSE. ``total = BCE + mse_weight * MSE``.

    Args:
        pos_weight: weight on positive subgoal-reached targets, default 6.0. Positives
            occur only near subgoal frames, so they are presumably the minority class.
        mse_weight: multiplier on the action MSE term, default 100.0.
    """

    def __init__(self, pos_weight=6.0, mse_weight=100.0):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, dtype=torch.float32))
        self.mse_loss = nn.MSELoss()
        self.mse_weight = mse_weight

    def forward(self, predicted_value, target):
        """Return the combined scalar loss for (batch, 1 + action_dim) predictions/targets."""
        bce = self.bce_loss(predicted_value[:, 0], target[:, 0])
        mse = self.mse_loss(predicted_value[:, 1:], target[:, 1:])
        return bce + self.mse_weight * mse

def find_largest_number(file_path): # Largest integer run id recorded in a parameter-mapping text file
    """Return the largest run id recorded in the text file at ``file_path``.

    Each entry line starts with an integer id followed by whitespace, e.g.
    ``"12        : state_space_both_..."``. Blank lines (such as a trailing empty line)
    and lines whose first token is not an integer are ignored. Returns 0 when no line
    carries an id, so the next run gets id 1.

    The maximum is used rather than the last line, so a hand-edited, out-of-order file
    cannot make the caller reuse an existing id and overwrite its saved model.

    Raises:
        OSError: if the file cannot be opened.
    """
    run_ids = []
    with open(file_path, 'r') as file:
        for line in file:
            tokens = line.split()
            if not tokens:
                continue
            try:
                run_ids.append(int(tokens[0]))
            except ValueError:
                continue  # header or free-form line without a leading id
    return max(run_ids, default=0)

if __name__ == '__main__':
    # Parameters
    train = True
    eval = True
    output_dimension = 9 # Action will always be 8 dimensional 7 dimension joint angles + 1 dimension gripper

    state_space = "both" # "joint_space", "both", "image_embedding"
    if(state_space == "joint_space"):
        state_dimension =60
    elif(state_space == "both"):
        state_dimension =1024+60 # image , joint
    elif(state_space == "image_embedding"):
        state_dimension = 1024 # See how many dimension is the image embedding????

    Trajectory_directories = ['1.1', '1.2', '1.3', '1.4', '1.5', '2.1', '2.2', '2.3', '2.4', '2.5', '3.1', '3.2', '3.3', '3.4', '3.5', '4.1', '4.2', '4.3', '4.4', '4.5', '5.1', '5.2', '5.3', '5.4', '5.5', '6.1', '6.2', '6.3', '6.4', '6.5', '7.1', '7.2', '7.3', '7.4', '7.5', '8.1', '8.2', '8.3', '8.4', '8.5', '9.1', '9.2', '9.3', '9.4', '9.5', '10.1', '10.2', '10.3', '10.4', '10.5', '11.1', '11.2', '11.3', '11.4', '11.5', '12.1', '12.2', '12.3', '12.4', '12.5', '13.1', '13.2', '13.3', '13.4', '13.5', '14.1', '14.2', '14.3', '14.4', '14.5', '15.1', '15.2', '15.3', '15.4', '15.5', '16.1', '16.2', '16.3', '16.4', '16.5', '17.1', '17.2', '17.3', '17.4', '17.5', '18.1', '18.2', '18.3', '18.4', '18.5', '19.1', '19.2', '19.3', '19.4', '19.5', '20.1', '20.2', '20.3', '20.4', '20.5', '21.1', '21.2', '21.3', '21.4', '21.5', '22.1', '22.2', '22.3', '22.4', '22.5', '23.1', '23.2', '23.3', '23.4', '23.5', '24.1', '24.2', '24.3', '24.4', '24.5', '25.1', '25.2', '25.3', '25.4', '25.5']
    list_of_subgoals_directory_to_eval = ['1.1', '1.2', '1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '1.9', '1.10', '2.1', '2.2', '2.3', '2.4', '2.5', '2.6', '2.7', '2.8', '2.9', '2.10', '3.1', '3.2', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '3.9', '3.10', '4.1', '4.2', '4.3', '4.4', '4.5', '4.6', '4.7', '4.8', '4.9', '4.10', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8', '5.9', '5.10', '6.1', '6.2', '6.3', '6.4', '6.5', '6.6', '6.7', '6.8', '6.9', '6.10', '7.1', '7.2', '7.3', '7.4', '7.5', '7.6', '7.7', '7.8', '7.9', '7.10', '8.1', '8.2', '8.3', '8.4', '8.5', '8.7', '8.8', '8.9', '8.10', '9.1', '9.2', '9.3', '9.4', '9.5', '9.6', '9.7', '9.8', '9.9', '9.10', '10.1', '10.2', '10.3', '10.4', '10.5', '10.6', '10.7', '10.8', '10.9', '10.10', '11.1', '11.2', '11.3', '11.4', '11.5', '11.6', '11.7', '11.8', '11.9', '11.10', '12.1', '12.2', '12.3', '12.4', '12.5', '12.6', '12.7', '12.8', '12.9', '12.10', '13.1', '13.2', '13.3', '13.4', '13.5', '13.6', '13.7', '13.8', '13.9', '13.10', '14.1', '14.2', '14.3', '14.4', '14.5', '14.6', '14.7', '14.8', '14.9', '14.10', '15.1', '15.2', '15.3', '15.4', '15.5', '15.6', '15.7', '15.8', '15.9', '15.10', '16.1', '16.2', '16.3', '16.4', '16.5', '16.6', '16.7', '16.8', '16.9', '16.10', '17.1', '17.2', '17.3', '17.4', '17.5', '17.6', '17.7', '17.8', '17.9', '17.10', '18.1', '18.2', '18.3', '18.4', '18.5', '18.6', '18.7', '18.8', '18.9', '18.10', '19.1', '19.2', '19.3', '19.4', '19.5', '19.6', '19.7', '19.8', '19.9', '19.10', '20.1', '20.2', '20.3', '20.4', '20.5', '20.6', '20.7', '20.8', '20.9', '20.10', '21.1', '21.2', '21.3', '21.4', '21.5', '21.6', '21.7', '21.8', '21.9', '21.10', '22.1', '22.2', '22.3', '22.4', '22.5', '22.6', '22.7', '22.8', '22.9', '22.10', '23.1', '23.2', '23.3', '23.4', '23.5', '23.6', '23.7', '23.8', '23.9', '23.10', '24.1', '24.2', '24.3', '24.4', '24.5', '24.6', '24.7', '24.8', '24.9', '24.10', '25.1', '25.2', '25.3', '25.4', '25.5', '25.6', '25.7', '25.8', '25.9', '25.10']
    total_number_of_iterations=1 # number of iterations per task (helpful with gaussian noise)

    num_epochs = 1000 # number of epochs on the training dataset
    dropout=0.0
    context_length = 1 # This is the maximum past history going in the network. context_length = 1 means no history is passed, only current state is given as input
    action_chunking = 10 # Action chunking = 1 means only 1 step prediction
    gaussian_noise = False # gaussian noise on the start state
    lr =  0.0003 # learning rate
    nhead=4 # number of attention heads. State dimension must be divisible by attention heads
    num_encoder_layers=4 # Number of encoder layers
    dim_feedforward=1024 # dimension of feedforward network
    camera_number = 2 # Camera for subgoals
    temporal_ensemble = 0 # Weight given for combining actions
    in_hand_eval = False # Get in hand camera video or not
    output_dimension*=action_chunking

    subgoal_conditioned = True # goal conditioned subgoals
    subgoal_directory_path = None
    subgoal_change_format = None # How to change subgoal using epsilon or neural nets
    subgoal_frames_delta = 0 # This is number of frames before the subgoal we consider the subgoal as achieved

    if(subgoal_conditioned):
        subgoal_change_format = "same_network" # same_network, epsilon. This tells how do we detect whether a subgoal is achieved or not during inference, "same_network" is ST-GPT
        context_length+=1 # For subgoal conditioned add one more elem for goal
        subgoal_directory_path = f"decomposed_frames/mininterval_18/divisions_1/gamma_0.08/camera_{camera_number}" # Can change 0.08 to something else if required

    saving_formatter = str(find_largest_number("./Parameter_mappings.txt")+1)

    with open('./Parameter_mappings.txt', 'a') as file: 
        file.write(f'{saving_formatter}        : state_space_{state_space}_num_epochs_{num_epochs}_lr_{lr}_dropout_{dropout}_context_length_{context_length}_nhead_{nhead}_num_encoder_layers_{num_encoder_layers}_dim_feedforward_{dim_feedforward}_camera_{camera_number}_gaussian_noise_{gaussian_noise}_subgoal_change_format_{subgoal_change_format}_subgoal_frames_delta_{subgoal_frames_delta}_action_chunking_{action_chunking}_temporal_ensemble_{temporal_ensemble}_Training_directory_{Trajectory_directories}\n')  # Add a newline character to separate lines
    model_dump_file_path = f"./Trained_Models/{saving_formatter}.pth"
    base_directory = f"./../../Data_Franka_Kitchen"

    if(subgoal_conditioned):
        if(subgoal_change_format == "same_network"):
            model = TransformerPolicy_Custom(state_dimension,output_dimension+1,context_length, nhead, num_encoder_layers, dim_feedforward, dropout).to(device)
        elif(subgoal_change_format == "epsilon"):
            model = TransformerPolicy_Custom(state_dimension,output_dimension,context_length, nhead, num_encoder_layers, dim_feedforward, dropout).to(device)
    else:
        model = TransformerPolicy_Custom(state_dimension,output_dimension,context_length, nhead, num_encoder_layers, dim_feedforward, dropout).to(device)

    print(model, flush=True)
    total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters in the neural network is: ", total_params, flush=True)

    if(train):
        train_start_time = time.time()
        trajectory_dataset = TrajectoryDataset( Trajectory_directories, base_directory , state_space, subgoal_directory_path, camera_number, action_chunking,subgoal_conditioned, subgoal_frames_delta, subgoal_change_format ,context_length )
        data_loader = DataLoader(trajectory_dataset, batch_size=512, shuffle=True )
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) # cosine decay of learning rate
        if(subgoal_conditioned and subgoal_change_format == "same_network"): # add a new loss function here
            loss_function = BCEWithLogitsLoss_MSELoss() # This is a custom loss function defined above for only this type of network
        else: # if not subgoal conditioned or if subgoal change format is epsilon or different network
            loss_function = torch.nn.MSELoss()

        for epoch in range(num_epochs): # Supervised Learning Loop
            model.train()  # Set the model to training mode
            running_loss = 0.0  # Initialize running loss for the epoch
            num_batches = 0     # Initialize batch counter
            for batch_idx, ( states, actions, timestamps) in enumerate(data_loader):
                states = states.to(device)            # Shape: (batch_size, state_dim)
                actions = actions.to(device)          # Shape: (batch_size, action_dim)
                timestamps = timestamps.to(device)           # Shape: (batch_size,)
                optimizer.zero_grad()
                predicted_actions = model(states, timestamps)
                loss = loss_function(predicted_actions, actions)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()  # Accumulate loss
                num_batches += 1             # Increment batch counter

            scheduler.step()

            if(epoch%50 == 0): # Print loss every 50 epochs
                current_lr = optimizer.param_groups[0]['lr'] # Current learning rate
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches}, Learning Rate: {current_lr}", flush=True)

        torch.save(model.state_dict(), model_dump_file_path)
        print(f"Model saved to {model_dump_file_path}", flush=True)

        print("Time taken to train the model is ", (time.time() - train_start_time)/3600 , " hrs")

    if(eval):
        length_of_trajectories_during_inference = {}
        reset_info_of_trajectories_during_inference = {} # information about the trajectory to infer
        for directory in list_of_subgoals_directory_to_eval:
            file_path = f"{base_directory}/{directory}/data.pkl" # pkl file path
            with open(file_path, 'rb') as f: # Read the pickel file
                data_dict = pickle.load(f)
                length_of_trajectories_during_inference[directory] = data_dict['observations'].shape[0]
                reset_info_of_trajectories_during_inference[directory] = (data_dict['init_qpos'] , data_dict['init_qvel']) 

        for iteration_number in range(1,total_number_of_iterations+1,1): # Number of times to evaluate a single trajectory, to get the evaluation metrics
            for directory_for_subgoals in list_of_subgoals_directory_to_eval: # Trajectory_directories: # These are all the trajectories to get subgoals from and evaluate
                video_resolution = (224, 224) # This is during evaluation
                reset_information  = reset_info_of_trajectories_during_inference[directory_for_subgoals]
                robot_env = Robotic_Environment(video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval)

                def robot_inference(directory_for_subgoals): # Function to actually evaluate the neural network
                    max_steps = length_of_trajectories_during_inference[directory_for_subgoals]
                    if(subgoal_conditioned):
                        subgoals_directory = f"{base_directory}/{directory_for_subgoals}/{subgoal_directory_path}"
                        files = os.listdir(subgoals_directory)
                        png_files = [f for f in files if f.endswith('.png')]
                        numbers = [int(f.replace('.png', '')) for f in png_files]
                        list_of_subgoals = sorted(numbers) # This is the sorted list of all the subgoals for some trajectory
                        list_of_subgoals.pop(0) # Dont want initial state to be a subgoal
                        actual_subgoals = [] # This is either 8 or 4 or 1024 dimensional
                        if(state_space == "joint_space"):
                            file_path = f"{base_directory}/{directory_for_subgoals}/data.pkl"
                            with open(file_path, 'rb') as f: # Read the pickel file
                                data_dict = pickle.load(f)
                            observations = data_dict['observations']  # Shape: (244, 60)
                            actions = data_dict['actions']  # Shape: (244, 9)
                            data = []
                            for i in range(observations.shape[0]):
                                observation = observations[i]
                                action = actions[i]
                                row = list(observation) + [0., 0., 0.] + [i] + list(action) # Create the row: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns
                                data.append(row) # data now in csv format
                            for subgoal_index in list_of_subgoals:
                                actual_subgoals.append(data[subgoal_index][:60])
                        elif(state_space == "both"):
                            video_path = f"{base_directory}/{directory_for_subgoals}/camera_{camera_number}.avi"
                            cap = cv2.VideoCapture(video_path)
                            for subgoal_index in list_of_subgoals:
                                cap.set(cv2.CAP_PROP_POS_FRAMES, subgoal_index)
                                ret, frame = cap.read()
                                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                                preprocessed_image = preprocessed_image.to(device)
                                with torch.no_grad():
                                    subgoal_embedding = vip(preprocessed_image * 255.0)
                                actual_subgoals.append(subgoal_embedding.cpu().tolist()[0])
                            cap.release()
                            file_path = f"{base_directory}/{directory_for_subgoals}/data.pkl"
                            with open(file_path, 'rb') as f: # Read the pickel file
                                data_dict = pickle.load(f)
                            observations = data_dict['observations']  # Shape: (244, 60)
                            actions = data_dict['actions']  # Shape: (244, 9)
                            data = []
                            for i in range(observations.shape[0]):
                                observation = observations[i]
                                action = actions[i]
                                row = list(observation) + [0., 0., 0.] + [i] + list(action) # Create the row: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns
                                data.append(row) # data now in csv format
                            for iterator in range(len(list_of_subgoals)):
                                subgoal_index = list_of_subgoals[iterator]
                                actual_subgoals[iterator] += data[subgoal_index][:60] # Add the 
                        elif(state_space == "image_embedding"):
                            video_path = f"{base_directory}/{directory_for_subgoals}/camera_{camera_number}.avi"
                            cap = cv2.VideoCapture(video_path)
                            for subgoal_index in list_of_subgoals:
                                cap.set(cv2.CAP_PROP_POS_FRAMES, subgoal_index)
                                ret, frame = cap.read()
                                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                                preprocessed_image = preprocessed_image.to(device)
                                with torch.no_grad():
                                    subgoal_embedding = vip(preprocessed_image * 255.0)
                                actual_subgoals.append(subgoal_embedding.cpu().tolist()[0])
                            cap.release()
                    model.load_state_dict(torch.load(model_dump_file_path))
                    model.eval()  # Set the model to evaluation mode
                    if(subgoal_conditioned):
                        current_subgoal_index = 0
                        if(subgoal_change_format == "epsilon"): # Define the subgoal change epsilon values for different tasks and state spaces
                            if(state_space == "joint_space"):
                                subgoal_change_epsilon = 0.008
                            elif(state_space == "image_embedding"):
                                subgoal_change_epsilon = 0.01
                            elif(state_space == "both"):
                                subgoal_change_epsilon = 0.005
                    history = [] # This is history of all states till now (within context length)
                    timestamps_history = []
                    for i in range(context_length):
                        history.append(robot_env.get_current_state(state_space))
                        timestamps_history.append(0)
                    if(subgoal_conditioned):
                        history.pop()
                        timestamps_history.pop()

                    Buffer = [[] for _ in range(max_steps + action_chunking)]  # Initialize buffer correctly
                    for i in range(max_steps):
                        if(subgoal_conditioned and (current_subgoal_index == len(actual_subgoals)) ):
                            break
                        if(i%100==0):
                            print(f"Timestamp: {i}/{max_steps}")
                        state = robot_env.get_current_state(state_space)
                        history.pop(0)
                        history.append(state)
                        timestamps_history.pop(0)
                        timestamps_history.append(i)
                        state = history.copy()
                        timestamps_state = timestamps_history.copy()
                        if(subgoal_conditioned):
                            state.append(actual_subgoals[current_subgoal_index])
                            timestamps_state.append(i+1)
                        state_tensor = torch.tensor([state], dtype=torch.float32).to(device)
                        timestamps_state = torch.tensor([timestamps_state], dtype=torch.float32).to(device)
                        with torch.no_grad(): 
                            action = model(state_tensor , timestamps_state)
                        action = action.cpu().tolist()[0]

                        if(subgoal_conditioned and subgoal_change_format == "same_network"):
                            subgoal_achievement_bit = action[0]
                            action = action[1:] # action without the subgoal achievement bit

                        action = np.array(action, dtype='float32')
                        action = action.reshape(action_chunking, output_dimension // action_chunking)  # Reshape into action chunks

                        for j in range(action_chunking): # Add the action chunks to the buffer
                            Buffer[i + j].append(action[j])
                        weights = np.exp(-temporal_ensemble * np.arange(len(Buffer[i])))  # Perform temporal ensemble: weighted average of the actions
                        weights /= weights.sum()  # Normalize weights
                        current_action = np.sum([w * a for w, a in zip(weights, Buffer[i])], axis=0)
                        current_action = current_action.tolist()  # Convert to list before passing to `step`
                        robot_env.step(current_action)

                        if(subgoal_conditioned): # see if the subgoal is achieved then transition to the next one
                            if(subgoal_change_format == "same_network"):
                                if(subgoal_achievement_bit > 0):
                                    print(f"Subgoal number {current_subgoal_index+1} achieved at timestamp {i}")
                                    current_subgoal_index+=1
                            elif(subgoal_change_format == "epsilon"):
                                if ( np.mean((np.array(robot_env.get_current_state(state_space)) - np.array(actual_subgoals[current_subgoal_index]))**2)  < subgoal_change_epsilon):
                                    print(f"Subgoal number {current_subgoal_index} achieved at timestamp {i}")
                                    current_subgoal_index+=1

                print(f"Evaluating subgoals from {directory_for_subgoals}, iteration number {iteration_number}...", flush=True)
                robot_inference(directory_for_subgoals)
                print("------------------------------")

                video_path = f"./Evaluation/{saving_formatter}/subgoals_{directory_for_subgoals}"
                os.makedirs(video_path, exist_ok=True) # Directory to save Evaluation Videos
                video_filename = f"{video_path}/{iteration_number}.mp4"
                video_filename_in_hand = f"{video_path}/{iteration_number}_in_hand.mp4"
                robot_env.save_video(video_filename, video_filename_in_hand)